#coding=utf-8
# Copyright (c) 2016 Tinydot. inc.
# All Rights Reserved.
#
#    Licensed under the Apache License, Version 2.0 (the "License"); you may
#    not use this file except in compliance with the License. You may obtain
#    a copy of the License at
#
#         http://www.apache.org/licenses/LICENSE-2.0
#
#    Unless required by applicable law or agreed to in writing, software
#    distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
#    WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
#    License for the specific language governing permissions and limitations
#    under the License.
import tensorflow as tf
import numpy as np
from google.protobuf import text_format
from tensorflow.core.framework import types_pb2 ## intermediate generation
from ml_feature.util.Operation import load_pbfile
import codecs
import os

def get_tensor_content_str(dtype, shape):
    """Return the textual ``tensor_content`` field of an all-ones TensorProto.

    Builds a tensor of ones with the given numpy ``dtype`` and ``shape``,
    serializes it via ``tf.make_tensor_proto`` and extracts everything after
    the ``tensor_content:`` marker of the proto's text representation.

    :param dtype: numpy dtype of the tensor (e.g. ``np.float32``)
    :param shape: tensor shape (list/tuple of ints)
    :return: the tensor_content text, or '' when the proto carries no
             tensor_content field (small tensors are serialized into typed
             ``*_val`` fields instead)
    """
    tf_pb = tf.make_tensor_proto(np.ones(shape, dtype=dtype), shape=shape)
    txt = str(tf_pb)
    key = "tensor_content:"
    pos = txt.find(key)
    if pos < 0:
        # The original sliced from find()+len(key) even when find() returned
        # -1, which yielded a garbage chunk of the proto text.
        return ''
    return txt[pos + len(key):].strip()
def load_pbtxt_file(url, flag=True):
    '''
    Load a text-format GraphDef from ``url``.

    Parameters:
        url: path of the .pbtxt file.
        flag: True returns the GraphDef; False returns a tf.Graph built
              by importing that GraphDef.
    '''
    graph_def = tf.GraphDef()
    with open(url, 'rb') as handle:
        text_format.Merge(handle.read(), graph_def)
    if not flag:
        graph = tf.Graph()
        with graph.as_default():
            tf.import_graph_def(graph_def, name='')
        return graph
    return graph_def
def get_dtype_index_name_dic():
    """Map TensorFlow DataType enum numbers to their names (e.g. 1 -> 'DT_FLOAT').

    :return: dict of {enum number: enum name string}
    """
    # Dict comprehension replaces the manual accumulation loop.
    return {d.number: d.name for d in types_pb2._DATATYPE.values}
def clean_string(dirty_string):
    """Drop every newline from *dirty_string* and trim surrounding whitespace."""
    # replace-then-strip is equivalent to the strip-then-replace order,
    # since strip() also removes edge newlines.
    return dirty_string.replace('\n', '').strip()
def get_indent_string(num):
    """Return *num* tab characters ('' when num <= 0).

    :param num: indentation depth
    :return: string of ``num`` tabs
    """
    # String repetition replaces the manual concatenation loop.
    return '\t' * num
def find_in_out(index, name, name_inputs):
    """Return (in_node, out_node) for the node ``name[index]``.

    in_node: every other node that appears among this node's inputs.
    out_node: every other node whose inputs contain this node.

    :param index: position of the target node in ``name``
    :param name: list of node names
    :param name_inputs: mapping of node name -> iterable of its input names
    """
    target = name[index]
    own_inputs = name_inputs[target]
    in_node = [other for other in name_inputs
               if other != target and other in own_inputs]
    out_node = [other for other in name_inputs
                if other != target and target in name_inputs[other]]
    return in_node, out_node


def getOutputshape(graph_def, name):
    '''
    Get the output shapes of the node in graph_def whose name equals ``name``.
    :param graph_def: GraphDef holding the node information
    :param name: name of the target node
    :return: list of shapes, one list of dim sizes per output
    '''
    ans = []
    for node in graph_def.node:
        if node.name == name:
            # One TensorShapeProto per output in the _output_shapes attr.
            for out in node.attr['_output_shapes'].list.shape:
                # str(out.dim) presumably renders like "[size: 1, size: 224]"
                # once newlines are removed — TODO confirm the proto repr.
                shape_string = str(out.dim).replace('\n','').strip()[1:-1]
                if shape_string != '':
                    arr = []
                    for i in shape_string.split(','):
                        if i == '':
                            arr.append('')
                        else:
                            # keep the number after the "size:" label
                            arr.append(int(i.split(':')[1]))
                else:
                    arr = []  # empty repr -> scalar / unknown shape
                ans.append(arr)
    return ans


def graphdef_mlc_json(graph_def, mlc_feature_dict, sequence):
    '''
    Convert graph_def into mlc content, formatted as json (OrderedDict keyed
    by node name).
    :param graph_def: TensorFlow GraphDef holding the model information
    :param mlc_feature_dict: OrderedDict of the feature info to generate
    :param sequence: per-node feature values, indexed in node order
    :return: ordered_dict
    '''
    num2name = get_dtype_index_name_dic() # maps dtype enum number to its name string
    # Record the second input of every Reshape node (its shape tensor);
    # those tensors keep their real tensor_content below.
    reshape_sec_input = []
    for node in graph_def.node:
        if node.op == 'Reshape':
            reshape_sec_input.append((node.input)[1])
    from collections import OrderedDict
    graph_def_order_dict = OrderedDict()
    for node_index, node in enumerate(graph_def.node):
        node_dict = dict()
        op = node.op
        device = node.device
        input_list = []
        for input_name in node.input:
            input_dict = dict()
            input_dict['name'] = input_name
            input_dict['shape'] = getOutputshape(graph_def, input_name)
            input_list.append(input_dict)
        output_list = []
        # Parse the textual proto repr of each output shape (same scheme as
        # getOutputshape).
        for out in node.attr['_output_shapes'].list.shape:
            shape_string = str(out.dim).replace('\n', '').strip()[1:-1]
            if shape_string != '':
                arr = []
                for i in shape_string.split(','):
                    if i=='':
                        arr.append('')
                    else:
                        arr.append(int(i.split(':')[1]))
            else:
                arr = []
            output_list.append(arr)
        attr_keys = node.attr.keys()
        attr_values = node.attr.values()
        attrs_dict = dict()
        # NOTE: list(attr_keys)/list(attr_values) are re-materialized on every
        # access; this relies on the attr map's iteration order staying stable
        # while the node is unchanged.
        for index in range(len(attr_keys)):
            if list(attr_keys)[index] == '_output_shapes':
                continue
            if list(attr_values)[index].HasField('tensor'):
                # tensor-valued attr
                tensor_dict = dict()
                tensor_dict['dtype'] = str(num2name[list(attr_values)[index].tensor.dtype])
                tensor_shape = []
                for size in list(attr_values)[index].tensor.tensor_shape.dim:
                    # str(size) looks like "size: N"; keep the number after it
                    tensor_shape.append( int( str(size)[5:].replace('\n','').strip() ))
                tensor_dict['tensor_shape'] = tensor_shape
                if len(list(attr_values)[index].tensor.int_val) > 0:
                    tensor_dict['int_val'] = str(list(attr_values)[index].tensor.int_val[0])
                if len(list(attr_values)[index].tensor.tensor_content)>0:
                    if node.name in reshape_sec_input:
                        # Reshape's shape tensor: keep the real bytes
                        def get_node_tensor_content_str(node):
                            """make tensor pb string"""
                            txt = str(node)
                            key = "tensor_content:"
                            pos = txt.find(key) + len(key)
                            return txt[pos:].split("}")[0].strip()
                        tensor_dict['tensor_content'] = get_node_tensor_content_str(node)
                    else:
                        # other weights are replaced by the 'ones' placeholder
                        tensor_dict['tensor_content'] = 'ones'
                attrs_dict[str(list(attr_keys)[index]).strip()] = tensor_dict
            else:
                if len(str(list(attr_values)[index].list)) == 0:
                    # plain (non-list) attr
                    if clean_string(str(list(attr_keys)[index]))=='shape':
                        if op == 'Placeholder':
                            shape_arr_string = list(attr_values)[index].shape
                            if str(shape_arr_string) != '':
                                shape_arr = [int(str(i).split(':')[1]) for i in shape_arr_string.dim]
                                attrs_dict['shape'] = shape_arr
                            else:
                                attrs_dict['shape'] = []
                        else:
                            continue
                    else:
                        not_list_split_arr = str(list(attr_values)[index]).split(':')
                        if len(not_list_split_arr) == 1:
                            attrs_dict[clean_string(str(list(attr_keys)[index]))] = clean_string(str(list(attr_values)[index]))
                        else:
                            attrs_dict[clean_string(str(list(attr_keys)[index]))]= '<'+clean_string(not_list_split_arr[0])+'>'+clean_string(not_list_split_arr[1])
                else:
                    # list-valued attr -> "<type>[v1,v2,...]"
                    list_split_arr = str(list(attr_values)[index].list).split('\n')
                    list_string = '<' + clean_string(list_split_arr[0].split(':')[0]) + '>'
                    list_val = []
                    for val in list_split_arr:
                        if len(val) > 0:
                            list_val.append(clean_string(val.split(':')[1]))
                    list_string += '[' + ','.join(list_val) + ']'
                    attrs_dict[clean_string(str(list(attr_keys)[index]))] = list_string
        node_dict['op'] = op
        node_dict['device'] = device
        node_dict['input'] = input_list
        node_dict['output'] = output_list
        node_dict['attrs'] = attrs_dict
        feature_dict = dict()
        # Attach features only when this node's sequence entry is a dict.
        for feature in mlc_feature_dict:
            if type(sequence[node_index])==type(dict()):
                feature_dict[feature] = {'value':sequence[node_index][feature],'unit':mlc_feature_dict[feature]}
        if len(feature_dict) != 0:
            node_dict['feature'] = feature_dict
        graph_def_order_dict[node.name] = node_dict
    return graph_def_order_dict


def transfer_graphdef_o(graph_def, file_dir, file_name, format_name, dic_op_name_shape):
    """Serialize graph_def into a .o pseudo-instruction file.

    Line format: op:name:device[inputs,][key->value;...]{in{}:out{}}
    The tensor_content of Reshape's shape input needs special handling.

    :param graph_def: GraphDef to serialize
    :param file_dir: directory of the output file
    :param file_name: output file name (without suffix)
    :param format_name: output file suffix (e.g. '.mlc')
    :param dic_op_name_shape: presumably maps "node_name:0" to that tensor's
        shape; used to annotate the in{}/out{} section — TODO confirm with caller
    """
    result = []
    name = []
    name_inputs = dict()
    num2name = get_dtype_index_name_dic()
    # Record the second input of every Reshape node (its shape tensor).
    reshape_sec_input = []
    for node in graph_def.node:
        if node.op == 'Reshape':
            reshape_sec_input.append((node.input)[1])
    for node in graph_def.node:
        name.append(node.name)
        name_inputs[node.name] = node.input
        node_string = str(node.op)+':'+str(node.name)+':'+str(node.device)+'['+','.join(node.input)+']['
        attr_keys = node.attr.keys()
        attr_values = node.attr.values()
        key_val_list = []
        for index in range(len(attr_keys)):
            key_val_string = clean_string(str(list(attr_keys)[index])) + '->'
            if list(attr_values)[index].HasField('tensor'):
                # tensor-valued attr
                key_val_string += '{dtype->'+str(num2name[list(attr_values)[index].tensor.dtype])+',tensor_shape->['
                size_list = []
                for size in list(attr_values)[index].tensor.tensor_shape.dim:
                    # str(size) looks like "size: N"; keep the number after it
                    size_list.append(clean_string(str(size)[5:]))
                key_val_string+=','.join(size_list)+']'
                if len(list(attr_values)[index].tensor.int_val)>0:
                    key_val_string +=',int_val->'+str(list(attr_values)[index].tensor.int_val[0])
                if len(list(attr_values)[index].tensor.tensor_content)>0:
                    if node.name in reshape_sec_input:
                        # Reshape's shape tensor: keep the real bytes
                        def get_node_tensor_content_str(node):
                            """make tensor pb string"""
                            txt = str(node)
                            key = "tensor_content:"
                            pos = txt.find(key) + len(key)
                            return txt[pos:].split("}")[0].strip()
                        key_val_string += ',tensor_content->'+get_node_tensor_content_str(node)
                    else:
                        # other weights are replaced by the 'ones' placeholder
                        key_val_string += ',tensor_content->ones'
                key_val_string+='}'
            else:
                if len(str(list(attr_values)[index].list))==0:
                    # plain (non-list) attr
                    not_list_split_arr = str(list(attr_values)[index]).split(':')
                    if len(not_list_split_arr)==1:
                        key_val_string += clean_string(str(list(attr_values)[index]))
                    else:
                        key_val_string += '<'+clean_string(not_list_split_arr[0])+'>'+clean_string(not_list_split_arr[1])
                else:
                    # list-valued attr -> "<type>[v1,v2,...]"
                    list_split_arr = str(list(attr_values)[index].list).split('\n')
                    key_val_string += '<'+clean_string(list_split_arr[0].split(':')[0])+'>'
                    list_val = []
                    for val in list_split_arr:
                        if len(val)>0:
                            list_val.append(clean_string(val.split(':')[1]))
                    key_val_string += '['+ ','.join(list_val) + ']'
            key_val_list.append(key_val_string)
        node_string += ';'.join(key_val_list) + ']'
        result.append(node_string)
    # Append the in{}/out{} topology section, annotated with tensor shapes.
    for index in range(len(name)):
        in_node, out_node = find_in_out(index, name, name_inputs)
        for in_index in range(len(in_node)):
            in_node[in_index] = str(in_node[in_index]) + str(dic_op_name_shape[str(in_node[in_index])+':0'])
        for out_index in range(len(out_node)):
            out_node[out_index] = str(out_node[out_index]) + str(dic_op_name_shape[str(name[index])+':0'])
        result[index] += '{in{'+ ','.join(in_node) +'}:out{'+ ','.join(out_node) +'}}\n'
    ## save to file
    path = os.path.abspath(os.path.join(file_dir,file_name+format_name))
    with open(path, 'w+') as f:
        f.write(''.join(result))


def mlc_json_graphdef(json_ordered_dict, emit=None, emit_value=None):
    """
    Convert the dict read from an mlc file into graph_def content text.
    :param json_ordered_dict: dict produced from the mlc file
    :param emit: optional callback driving a progress bar
    :param emit_value: progress value when this function is entered
    :return: the generated graph_def content string
    """
    # NOTE(review): this function appears truncated — the tensor_content
    # branch at the bottom ends with `type_np_dict = {}` and nothing follows,
    # `attrs_content` is never appended to `attrs_list`, and `result` is never
    # returned. Confirm against version history before relying on it.
    result = []
    node_num = 1
    for name in json_ordered_dict:
        if name == '':
            continue
        if emit:
            emit('analysis node(%d/%d)'%(node_num,len(json_ordered_dict)),emit_value)
            node_num += 1
        content = 'node {\n'
        content += 'name: "%s"\n'%(name)
        content += 'op: "%s"\n'%(json_ordered_dict[name]['op'])
        input_list = json_ordered_dict[name]['input']
        for input_dict in input_list:
            content += 'input: "%s"\n'%(input_dict['name'])
        content += 'device: "%s"\n'%(json_ordered_dict[name]['device'])

        attrs_dict = json_ordered_dict[name]['attrs']
        attrs_list = []
        outputs_list = json_ordered_dict[name]['output']
        output_shape_list = []
        # Rebuild the _output_shapes attr from the stored output dims.
        for output in outputs_list:
            output_shape_content = 'shape {\n'
            for shape_dim in output:
                output_shape_content += 'dim {\nsize: %d\n}\n'%(shape_dim)
            output_shape_content += '}\n'
            output_shape_list.append(output_shape_content)
        output_string = ''.join(output_shape_list)
        attrs_list.append('attr {\nkey: "_output_shapes"\nvalue {\nlist {\n%s}\n}\n}\n'%output_string)
        for key in attrs_dict:
            attrs_content = 'attr {\nkey: "%s"\nvalue {\n'%key
            if type(attrs_dict[key]) == type([]):
                # list of dims -> Placeholder-style 'shape' attr
                attrs_content += 'list {\nshape {\n'
                for shape_dim in attrs_dict[key]:
                    attrs_content += 'dim {\nsize: %d\n}\n'%shape_dim
                attrs_content += '}\n}\n}\n}\n'
            elif type(attrs_dict[key]) == type(''):
                if attrs_dict[key].startswith('<'):
                    # "<type>value" or "<type>[v1,v2,...]"
                    arr = attrs_dict[key].split('>')
                    type_string = arr[0][1:]
                    if arr[1].startswith('['):
                        # NOTE(review): eval of file-derived text — only feed
                        # this function trusted mlc files.
                        list_content_arr = eval(arr[1])
                        attrs_content += 'list {\n'
                        for list_dim in list_content_arr:
                            attrs_content += '%s: %d\n'%(type_string,list_dim)
                        attrs_content += '}\n}\n}\n'
                    else:
                        attrs_content += '%s: %s\n}\n}\n'%(type_string,arr[1])
                else:
                    attrs_content += '%s\n}\n}\n'%(attrs_dict[key])
            elif type(attrs_dict[key]) == type(dict()):
                # tensor-valued attr
                attrs_content += 'tensor {\ndtype: %s\ntensor_shape {\n'%attrs_dict[key]['dtype']
                tensor_shape = attrs_dict[key]['tensor_shape']
                for tensor_shape_dim in tensor_shape:
                    attrs_content += 'dim {\nsize: %d\n}\n'%tensor_shape_dim
                attrs_content +='}\n'
                if 'tensor_content' in attrs_dict[key]:
                    if attrs_dict[key]['tensor_content'] == 'ones':
                        type_np_dict = {}


def _split_top_level(csv_string):
    """Split ``csv_string`` on commas that sit outside ``[...]`` brackets.

    Used for name[d1,d2]-style lists where commas inside the bracketed shape
    must not act as separators. Returns [] for an empty string.
    """
    parts = []
    start = 0
    outside_brackets = True
    for pos, ch in enumerate(csv_string):
        if ch == '[':
            outside_brackets = False
        elif ch == ']':
            outside_brackets = True
        elif ch == ',' and outside_brackets:
            parts.append(csv_string[start:pos])
            start = pos + 1
    # Keep the segment after the last top-level comma (the original scanners
    # silently dropped it).
    if start < len(csv_string):
        parts.append(csv_string[start:])
    return parts


def transfer_o_pbtxt(url, out_url='', emit=None, emit_value=None):
    """Convert a .o pseudo-instruction file back into pbtxt content.

    Line format: op:name:device[inputs,][key->value;...]{in{}:out{}}
    The tensor_content of Reshape's shape input needs special handling.

    :param url: path of the .o file to read (utf-8)
    :param out_url: when non-empty, write the pbtxt there and return None;
                    otherwise return (content, placeholder_valuelist, opname_arr)
    :param emit: optional progress callback, called as emit(message, emit_value)
    :param emit_value: progress value passed through to ``emit``
    """
    with codecs.open(url, 'r+', encoding='utf-8') as f:
        node_list = f.readlines()
    result = []
    placeholder_valuelist = []  # output shapes of every Placeholder node
    opname_arr = []             # names of every Placeholder node
    node_num = 1
    for node in node_list:
        if node == '':
            continue
        if emit:
            emit('analysis node(%d/%d)' % (node_num, len(node_list)), emit_value)
            node_num += 1
        # Locate the boundaries of the four sections of one line.
        section_2_start = node.find('[')
        section_2_end = node.find(']', section_2_start)
        section_3_start = section_2_end + 1
        section_4_start = node.find('{in{', section_3_start)
        section_3_end = section_4_start - 1
        section_4_end = len(node) - 2
        # Slice each section out of the line.
        section_1 = node[0:section_2_start]                  # op:name:device
        section_2 = node[section_2_start+1:section_2_end]    # inputs
        section_3 = node[section_3_start+1:section_3_end]    # attrs
        section_4 = node[section_4_start+1:section_4_end]    # in{...}:out{...}
        content = 'node{\n'  # pbtxt content generated for this node
        z_index = 1          # current indentation depth of the generated pbtxt
        # Split each section by its dedicated separator.
        basic_attr_list = section_1.split(':', 2)
        input_list = section_2.split(',')
        attr_list = section_3.split(';')
        # Extract in_list / out_list; commas inside [...] are not separators.
        in_out_list = section_4.split(':')
        in_list_string = in_out_list[0][3:-1]
        out_list_string = in_out_list[1][4:-1]
        # BUGFIX: the original out_list scanner advanced its slice start by 1
        # per comma (copy-paste error) instead of jumping past the comma, and
        # both scanners dropped the last segment; _split_top_level fixes both.
        in_list = _split_top_level(in_list_string)
        out_list = _split_top_level(out_list_string)
        # Emit the node header.
        content += get_indent_string(z_index)+'name: "'+basic_attr_list[1]+'"\n'
        content += get_indent_string(z_index)+'op: "'+basic_attr_list[0]+'"\n'
        # For a Placeholder op, record its output shape and name for the caller.
        if basic_attr_list[0] == 'Placeholder':
            placeholder_valuelist.append(out_list[0][out_list[0].find('[')+1:-1].split(','))
            opname_arr.append(out_list[0][0:out_list[0].find('[')])
        # Emit the inputs.
        for input_name in input_list:
            if input_name != '':
                content += get_indent_string(z_index)+'input: "'+input_name+'"\n'
        content += get_indent_string(z_index)+'device: "'+basic_attr_list[2]+'"\n'
        # Emit the attr section.
        for attr in attr_list:
            content += get_indent_string(z_index)+'attr {\n'
            z_index += 1
            key_value = attr.split('->', 1)
            if key_value[1][0] == '{':
                # tensor-valued attr: {dtype->...,tensor_shape->[...],...}
                content += get_indent_string(z_index)+'key: "'+key_value[0]+'"\n'
                content += get_indent_string(z_index)+'value {\n'
                z_index += 1
                content += get_indent_string(z_index)+'tensor {\n'
                z_index += 1
                tensor = key_value[1][1:-1]
                tensor_arr = tensor.split('->')
                content += get_indent_string(z_index) + 'dtype: ' + tensor_arr[1].split(',')[0]+'\n'
                dim_list = tensor_arr[2].split(']')[0]
                content += get_indent_string(z_index) + 'tensor_shape {\n'
                z_index += 1
                for dim in dim_list[1:].split(','):
                    if dim != '':
                        content += get_indent_string(z_index)+'dim {\n'
                        z_index += 1
                        content += get_indent_string(z_index)+'size: '+dim+'\n'
                        z_index -= 1
                        content += get_indent_string(z_index)+'}\n'
                z_index -= 1
                content += get_indent_string(z_index)+'}\n'
                if len(tensor_arr) > 3:
                    extra_key = tensor_arr[2].split(']')[1][1:]
                    if extra_key == 'int_val':
                        content += get_indent_string(z_index)+extra_key+': '+tensor_arr[3]+'\n'
                    elif extra_key == 'tensor_content':
                        content += get_indent_string(z_index)+extra_key+': '
                        dtype_string = tensor_arr[1].split(',')[0]
                        if dtype_string == 'DT_FLOAT':
                            dtype_type = np.float32
                        elif dtype_string == 'DT_INT32':
                            dtype_type = np.int32
                        # 'ones' is the placeholder written by transfer_graphdef_o;
                        # regenerate the actual bytes from an all-ones tensor.
                        # NOTE: eval of file-derived text — trusted input only.
                        if tensor_arr[3] == 'ones':
                            content += get_tensor_content_str(dtype_type, eval(dim_list+']'))+'\n'
                        else:
                            content += tensor_arr[3]+'\n'
            elif key_value[1][key_value[1].find('>')+1] == '[':
                # list-valued attr: <type>[v1,v2,...]
                content += get_indent_string(z_index)+'key: "'+key_value[0]+'"\n'
                list_type = key_value[1][1:key_value[1].find('>')]
                content += get_indent_string(z_index)+'value {\n'
                z_index += 1
                content += get_indent_string(z_index)+'list {\n'
                z_index += 1
                for list_val in key_value[1][key_value[1].find('[')+1:key_value[1].find(']')].split(','):
                    content += get_indent_string(z_index)+list_type+': '+list_val+'\n'
            else:
                # plain attr
                content += get_indent_string(z_index)+'key: "'+key_value[0]+'"\n'
                content += get_indent_string(z_index)+'value {\n'
                z_index += 1
                if key_value[1].find('>') < 0:
                    # raw value with no <type> marker (Placeholder case)
                    content += get_indent_string(z_index)+key_value[1]+'\n'
                else:
                    list_type = key_value[1][1:key_value[1].find('>')]
                    list_val = key_value[1][key_value[1].find('>')+1:]
                    content += get_indent_string(z_index)+list_type+': '+list_val+'\n'
            # Close every brace opened for this attr.
            for tail_indent_ci in range(z_index-1):
                z_index -= 1
                content += get_indent_string(z_index)+'}\n'
        z_index -= 1
        content += get_indent_string(z_index) + '}\n'
        result.append(content)
    if out_url != '':
        with open(out_url, 'w+') as f:
            f.write(''.join(result))
    else:
        return ''.join(result), placeholder_valuelist, opname_arr

def generate_pb_pbtxt(graph_def, file_dir, file_name, file_suffix):
    '''
    Write the graph to file_dir/file_name+file_suffix.

    Parameters:
        graph_def : a graph_def or a graph, either works
        file_dir: directory for the generated file
        file_name: name of the generated file
        file_suffix: target suffix — '.pbtxt' writes text format, '.pb'
                     writes binary; any other suffix is a no-op
    '''
    target = file_name + file_suffix
    if file_suffix == '.pb':
        tf.train.write_graph(graph_def, file_dir, target, as_text=False)
    elif file_suffix == '.pbtxt':
        tf.train.write_graph(graph_def, file_dir, target)

def transfer_format(graph_def, file_dir, file_name, file_suffix, dic_op_name_shape=None):
    '''
    Dispatch graph serialization by target suffix.

    Parameters:
        graph_def: graph to serialize
        file_dir: directory where the file is saved
        file_name: name of the saved file
        file_suffix: target suffix ('.pb', '.pbtxt' or '.mlc')
        dic_op_name_shape: shape annotations forwarded to the .mlc writer

    Returns True when the suffix was handled, False otherwise.
    '''
    handled = True
    if file_suffix == '.mlc':
        transfer_graphdef_o(graph_def, file_dir, file_name, file_suffix, dic_op_name_shape)
    elif file_suffix in ('.pb', '.pbtxt'):
        generate_pb_pbtxt(graph_def, file_dir, file_name, file_suffix)
    else:
        handled = False
    return handled
